from defense.base_defense import BaseDefense
from compressai.zoo import load_state_dict
from ELIC_network import TestModel
from torchvision import transforms
import torch

class ELICDefense(BaseDefense):
    """
    ELIC Defense class for applying ELIC to images.

    Each input batch is resized to ``input_size`` and passed through the
    ELIC model ``iterations`` times. The model output's ``'x_hat'`` entry is
    fed back in on each pass. The result is then resized to ``output_size``.
    """
    def __init__(self, weights='0016', device=None, iterations=1, *,
                 input_size=(256, 256), output_size=(224, 224)):
        """
        Initialize the ELIC defense with the specified weights and device.

        Args:
            weights (str): Checkpoint identifier. The file loaded is
                ``data/ELIC_{weights}_ft_3980_Plateau.pth.tar`` (default: '0016').
            device (torch.device or str): Device to move the model to
                (default: None). Forwarded to ``BaseDefense``.
            iterations (int): Number of passes through the ELIC model that
                ``_defense`` performs (default: 1). Forwarded to ``BaseDefense``.
            input_size (tuple[int, int]): Size given to ``transforms.Resize``
                before the model is applied (default: (256, 256)).
            output_size (tuple[int, int]): Size given to ``transforms.Resize``
                on the returned tensor (default: (224, 224)).

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        super().__init__(device, iterations)
        checkpoint_path = f'data/ELIC_{weights}_ft_3980_Plateau.pth.tar'
        try:
            # Load onto CPU so checkpoints saved from GPU tensors also load on
            # CPU-only machines; the model is moved to self.device below.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Model weights for '{weights}' not found. Please provide valid weights in the data folder.") from err
        self.state_dict = load_state_dict(checkpoint)
        self.model = TestModel().from_state_dict(self.state_dict)
        self.model = self.model.to(self.device)
        self.model.eval()
        # Registered with BaseDefense; presumably so the model follows later
        # device changes -- TODO confirm against BaseDefense.
        self.device_attributes.append('model')
        self.input_size = input_size
        self.output_size = output_size

    def _defense(self, x):
        """
        Apply the defense to the input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: The input resized to ``self.input_size``, passed
            through the ELIC model ``self.iterations`` times (taking
            ``'x_hat'`` from each output), then resized to ``self.output_size``.
        """
        x = transforms.Resize(self.input_size)(x)
        for _ in range(self.iterations):
            # 'x_hat' is presumably the decompressed reconstruction.
            x = self.model(x)['x_hat']
        x = transforms.Resize(self.output_size)(x)
        return x